import argparse
import os
import tensorflow as tf
from tensorflow.keras import mixed_precision
from transformers import OpenAIGPTTokenizer
import time
from model import GPT1

# Prepare the dataset in tensorflow dataset format for training
# Prepare the dataset in tensorflow dataset format for training
def prepare_dataset(data_dir, batch_size):
    """Build a batched ``tf.data`` pipeline from the TFRecord files in ``data_dir``.

    Each record is expected to hold a single fixed-length int64 feature
    ``token_ids`` of 513 tokens (block_size + 1, so the caller can slice
    off shifted input/target pairs).

    Args:
        data_dir: Directory containing the TFRecord shard files.
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` yielding dicts ``{'token_ids': (batch, 513) int64}``.
    """
    feature_description = {
        'token_ids': tf.io.FixedLenFeature([513], tf.int64)
    }

    def _parse_function(example_proto):
        # Parse the input `tf.Example` proto using the dictionary above.
        return tf.io.parse_single_example(example_proto, feature_description)

    # os.path.join works whether or not `data_dir` has a trailing separator
    # (the original `data_dir + f` concatenation required a trailing '/').
    filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
    tf_ds = tf.data.TFRecordDataset(filenames)
    # Prefetch is applied LAST, after batching, so whole batches are
    # pipelined with training (the original prefetched individual examples).
    tf_ds = tf_ds\
        .map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)\
        .shuffle(buffer_size=batch_size*100)\
        .batch(batch_size)\
        .prefetch(tf.data.AUTOTUNE)

    return tf_ds

def loss_function(real, pred):
    """Mean sparse categorical cross-entropy between targets and logits.

    Args:
        real: Integer target token ids, shape (batch, seq).
        pred: Unnormalized logits, shape (batch, seq, vocab).

    Returns:
        A scalar tensor: the mean per-token cross-entropy.
    """
    sce = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    per_token_loss = sce(real, pred)
    return tf.reduce_mean(per_token_loss)

def accuracy_function(real, pred):
    """Fraction of positions where the argmax prediction matches the target.

    NOTE(review): no padding mask is applied — every position counts
    toward the average (a masked variant was present but disabled in the
    original source).
    """
    predicted_ids = tf.argmax(pred, axis=2)
    matches = tf.cast(tf.equal(real, predicted_ids), dtype=tf.float32)
    return tf.reduce_mean(matches)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='gpt1')
    parser.add_argument('--resume', type=str, help='Specify the CKPT name for resume training')
    parser.add_argument('--start_epoch', type=int, default=0, help='If resume training then specify the epoch to continue')
    parser.add_argument('--num_epoch', type=int, default=1, help='Specify the number of epochs to train')
    parser.add_argument('--steps_epoch', type=int, default=5000, help='Specify the steps of epoch')
    parser.add_argument('--total_epochs', type=int, default=30, help='Specify the total target epochs to train')
    parser.add_argument('--no_mixed', action='store_true', help='Specify this to not use mixed precesion to train')
    parser.add_argument('--dataset', type=str, default='dataset/', help='The tfrecord format dataset path')
    parser.add_argument('--block_size', type=int, default=512, help='The sequence lenght of the tokens for trianing')
    parser.add_argument('--decoder_layers', type=int, default=6, help='Decoder layers, orginial gpt1 model contains 12 layers')
    parser.add_argument('--heads', type=int, default=12, help='Multi attention heads per decoder layer')
    parser.add_argument('--d_model', type=int, default=768, help='Feature dimension for the total multi attention heads')
    parser.add_argument('--dff', type=int, default=3072, help='Feed forward layer feature dimension')
    parser.add_argument('--batch_size', type=int, default=16, help='Training batch size, original model use 64')
    parser.add_argument('--attn_pdrop', type=float, default=0.1)
    parser.add_argument('--resid_pdprop', type=float, default=0.1)
    parser.add_argument('--embed_pdrop', type=float, default=0.1)
    parser.add_argument('--learning_rate', type=float, default=0.0006, help='Original gpt1 use 0.00025')
    parser.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--warmup_steps', type=int, default=2000)
    parser.add_argument('--checkpoint_path', type=str, default='./checkpoints/train')
    parser.add_argument('--logfile', type=str, default='train_result.txt')
    args = parser.parse_args()

    # Enable float16 compute (with float32 variables) unless disabled.
    mixed = False
    if not args.no_mixed:
        mixed_precision.set_global_policy('mixed_float16')
        mixed = True
        print("Use Mixed Precision to train.")

    tf_ds = prepare_dataset(args.dataset, args.batch_size)

    # Vocabulary size is taken from the pretrained GPT-1 BPE tokenizer.
    tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
    vocab_size = len(tokenizer.get_vocab())

    # NOTE(review): vocab is padded by block_size extra ids — presumably to
    # reserve space for learned positional embeddings inside GPT1; verify
    # against model.py.
    model = GPT1(
        num_layers=args.decoder_layers, 
        d_model=args.d_model, 
        num_heads=args.heads, 
        dff=args.dff, 
        target_vocab_size=vocab_size+args.block_size, 
        block_size=args.block_size, 
        initializer=tf.random_normal_initializer(stddev=0.02), 
        mixed=mixed, 
        scale=True,
        attn_pdrop=args.attn_pdrop,
        resid_pdrop=args.resid_pdprop,
        embed_pdrop=args.embed_pdrop)

    # Linear warmup from 0 to the peak learning rate over `warmup_steps`,
    # then cosine decay across the full planned run.
    decay_steps = args.steps_epoch*args.total_epochs
    initial_learning_rate = 0.
    lr_warmup_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate, 
        decay_steps, 
        warmup_target=args.learning_rate,
        warmup_steps=args.warmup_steps)

    optimizer = tf.keras.optimizers.AdamW(lr_warmup_decayed_fn, beta_1=0.9, beta_2=0.999, epsilon=1e-8, weight_decay=args.weight_decay)
    if mixed:
        # Dynamic loss scaling avoids float16 gradient underflow.
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)

    # Per-token loss (reduction='none'); the mean is tracked via the metric.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')

    # Define two trackable objects to save
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, args.checkpoint_path, max_to_keep=5)

    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print('Latest checkpoint restored!!')


    train_step_signature = [
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
        tf.TensorSpec(shape=(None, None), dtype=tf.int64)
    ]
    @tf.function(input_signature=train_step_signature)
    def train_step(inp, tar):
        """Run one optimization step on an (input, shifted-target) batch."""
        with tf.GradientTape() as tape:
            predictions, _ = model(inp, training = True)
            loss = loss_object(tar, predictions)
            if mixed:
                # BUGFIX: get_scaled_loss() exists only on the
                # LossScaleOptimizer wrapper; the original called it
                # unconditionally and crashed with --no_mixed.
                scaled_loss = optimizer.get_scaled_loss(loss)

        if mixed:
            # Backprop through the scaled loss, then unscale the gradients
            # before applying them.
            scaled_gradients = tape.gradient(scaled_loss, model.trainable_variables)
            gradients = optimizer.get_unscaled_gradients(scaled_gradients)
        else:
            gradients = tape.gradient(loss, model.trainable_variables)

        # Skip variables that received no gradient this step.
        optimizer.apply_gradients((grad, var) for (grad, var) in zip(gradients, model.trainable_variables) if grad is not None)

        train_loss(loss)
        # Unmasked token-level accuracy (every position counts).
        accuracy = tf.cast(tf.equal(tar, tf.argmax(predictions, axis=2)), dtype=tf.float32)
        train_accuracy(accuracy)

    start_batch = args.start_epoch * args.steps_epoch
    for epoch in range(args.start_epoch, args.start_epoch+args.num_epoch):
        start = time.time()

        train_loss.reset_states()
        train_accuracy.reset_states()
        for (batch, inputs) in enumerate(tf_ds):
            try:
                # Inputs are tokens [0:-1], targets are tokens [1:] (next-token LM).
                train_step(inputs['token_ids'][...,:-1], inputs['token_ids'][...,1:])
            except ValueError:
                # Best-effort debug dump of the offending batch, then stop the epoch.
                print(inputs)
                break

            if batch % 100 == 0:
                line = f'Epoch {epoch + 1} Batch {batch+start_batch} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f} Learning rate {optimizer.lr.numpy():.5f}'
                with open(args.logfile, 'a') as logfile:
                    logfile.write(line+'\n')
                print(line)
                # Metrics are windowed: reset after every report.
                train_loss.reset_states()
                train_accuracy.reset_states()
                if batch >= args.steps_epoch:
                    break

        start_batch += args.steps_epoch
        ckpt_save_path = ckpt_manager.save()
        with open(args.logfile, 'a') as logfile:
            line = f'Saving checkpoint for epoch {epoch+1} at {ckpt_save_path}'
            logfile.write(line+'\n')
            print(line)
            line = f'Time taken for 1 epoch: {time.time() - start:.2f} secs\n'
            logfile.write(line+'\n')
            print(line)

        model.save('saved_model/gpt1_model')